# 导入必要的库
import numpy as np  # 用于处理多维数组和矩阵运算
import faiss  # 用于高效相似性搜索和稠密向量聚类
from util import createXY  # 用于创建数据集的特征和标签
from sklearn.model_selection import train_test_split  # 用于拆分数据集为训练集和测试集
from sklearn.neighbors import KNeighborsClassifier  # sklearn中的K近邻分类器
import argparse  # 用于解析命令行参数
import logging  # 用于记录日志
from tqdm import tqdm  # 用于在循环中显示进度条
from FaissKNeighbors import FaissKNeighbors  # 导入自定义的FaissKNeighbors类
import os

# 配置logging, 确保能够打印正在运行的函数名
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# 获取命令行参数
def get_args():
    parser = argparse.ArgumentParser(description='使用CPU或GPU训练模型。') # 创建命令行参数解析器对象
    parser.add_argument('-m', '--mode', type=str, default='cpu', choices=['cpu', 'gpu'], help='选择训练模式：CPU或GPU。默认 cpu')
    parser.add_argument('-f', '--feature', type=str, default='flat', choices=['flat', 'vgg'], help='选择特征提取方法：flat或vgg。默认 flat')
    parser.add_argument('-l', '--library', type=str, default='sklearn', choices=['sklearn', 'faiss'], help='选择使用的库：sklearn或faiss。默认 sklearn')
    args = parser.parse_args()
    return args

# 主函数，运行训练过程
def main():
    import traceback
    print('main() 启动', flush=True)
    try:
        args = get_args()
        print(f'参数: {args}', flush=True)
        # 根据mode初始化FAISS所需的资源；如当前 faiss 不支持 GPU，则优雅降级为 CPU
        res = None
        if args.mode == 'gpu':
            try:
                if hasattr(faiss, 'StandardGpuResources'):
                    res = faiss.StandardGpuResources()
                else:
                    logging.warning("当前 faiss 未提供 GPU 资源接口，已自动降级为 CPU 模式。建议安装带 GPU 的 faiss 发行版（Windows 下一般不可用）。")
            except Exception as e:
                logging.warning(f"初始化 GPU 资源失败，改用 CPU：{e}")
                res = None
        logging.info(f"选择模式是 {args.mode.upper()}")
        logging.info(f"选择特征提取方法是 {args.feature.upper()}")
        logging.info(f"选择使用的库是 {args.library.upper()}")

        # 载入和预处理数据
        # 使用脚本所在目录作为基准，构造绝对路径，避免因为工作目录不同导致找不到数据
        this_dir = os.path.dirname(__file__)
        train_folder = os.path.join(this_dir, 'data', 'train')
        cache_folder = this_dir  # 将缓存写在项目目录下

        print(f'train_folder={train_folder}', flush=True)
        print(f'cache_folder={cache_folder}', flush=True)

        X, y = createXY(train_folder=train_folder, dest_folder=cache_folder, method=args.feature)
        X = np.array(X).astype('float32')

        # 若数据为空，给出清晰提示并退出
        if X.size == 0:
            logging.error("未在数据目录中找到任何图像，或缓存为空。请检查路径是否存在并包含图片: %s", train_folder)
            print(f"未在数据目录中找到任何图像，或缓存为空。请检查路径是否存在并包含图片: {train_folder}", flush=True)
            with open(os.path.join(cache_folder, 'train_run.log'), 'a', encoding='utf-8') as f:
                f.write(f"[ERROR] 未在数据目录中找到任何图像，或缓存为空。请检查路径是否存在并包含图片: {train_folder}\n")
            return

        # 仅在非空时归一化
        faiss.normalize_L2(X)  # 对数据进行L2归一化
        y = np.array(y)
        logging.info("数据加载和预处理完成。")
        print(f"数据加载完成，X.shape={X.shape}, y.shape={y.shape}", flush=True)

        # 数据集分割为训练集和测试集
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=2023)
        logging.info("数据集划分为训练集和测试集。")
        print(f"训练集: {X_train.shape}, 测试集: {X_test.shape}", flush=True)

        # 初始化变量，跟踪最佳的k值和相应的准确率
        best_k = -1
        best_accuracy = 0.0

        # 定义测试的k值范围
        k_values = range(1, 6)
        
        # 根据提供的库选择K近邻算法实现
        KNNClass = FaissKNeighbors if args.library == 'faiss' else KNeighborsClassifier
        logging.info(f"使用的库为: {args.library.upper()}")

        # 遍历k值，训练并评估模型
        for k in tqdm(k_values, desc='寻找最佳k值'):
            print(f"训练 k={k}", flush=True)
            knn = KNNClass(k=k, res=res) if args.library == 'faiss' else KNNClass(n_neighbors=k)
            knn.fit(X_train, y_train)
            accuracy = knn.score(X_test, y_test)
            print(f"k={k} 的准确率: {accuracy}", flush=True)
            # 更新最佳k值和准确率
            if accuracy > best_accuracy:
                best_k = k
                best_accuracy = accuracy

        # 打印结果
        logging.info(f'最佳k值: {best_k}, 最高准确率: {best_accuracy}')
        print(f'最佳k值: {best_k}, 最高准确率: {best_accuracy}', flush=True)
        with open(os.path.join(cache_folder, 'train_run.log'), 'a', encoding='utf-8') as f:
            f.write(f'最佳k值: {best_k}, 最高准确率: {best_accuracy}\n')
    except Exception as e:
        print('发生异常:', e, flush=True)
        with open(os.path.join(os.path.dirname(__file__), 'train_run.log'), 'a', encoding='utf-8') as f:
            f.write('发生异常:\n')
            f.write(traceback.format_exc())
        raise

# 如果是主脚本，则执行main函数
if __name__ == '__main__':
    main()